import torch

from abc import ABC, abstractmethod

class ContinuousEnvironment(ABC):
    """
    Base class for continuous environments, where actions represent the transition vector from the current state to the next state.

    On construction the box ``[lower_bound, upper_bound]`` is discretised into a regular grid with
    ``num_grid_points`` points per dimension (endpoints included). The reward ``exp(log_reward(x))`` is
    evaluated on that grid and normalised into ``target_density``.

    Because ``log_reward`` is called from ``__init__``, subclasses must set any attributes that
    ``log_reward`` needs *before* calling ``super().__init__()``.
    """
    def __init__(self, config, dim, feature_dim, angle_dim, action_dim, lower_bound, upper_bound, mixture_dim, output_dim):
        """
        Args:
            config: Configuration dict. Must contain ``"device"`` and an ``"env"`` sub-dict with the keys
                ``"mixture_dim"``, ``"num_grid_points"`` and ``"init_value"``.
            dim: Dimension of the state space.
            feature_dim: Dimension of the feature space (input to the policy network).
            angle_dim: Boolean list indicating which dimensions are angles.
            action_dim: Dimension of the action space.
            lower_bound: Lower bounds of the state space, one per dimension.
            upper_bound: Upper bounds of the state space, one per dimension.
            mixture_dim: Number of mixture components in the parameterisation of the policy.
            output_dim: Dimension of the output of the policy network.

        Raises:
            AssertionError: If a required ``config["env"]`` key is missing, ``num_grid_points`` does not
                have one entry per dimension, or ``init_value`` lies outside the bounds.
            ValueError: If ``num_grid_points`` is not 1-dimensional or any entry is smaller than 2.
        """
        self.config = config
        self.dim = dim                  # Dimension of the state space
        self.feature_dim = feature_dim  # Dimension of the feature space (input to the policy network)
        self.angle_dim = angle_dim      # Boolean list, indicating which dimensions are angles
        self.action_dim = action_dim    # Dimension of the action space
        self.reward_Z = None
        self.device = torch.device(config["device"])

        required_params = ["mixture_dim", "num_grid_points", "init_value"]
        missing_params = [param for param in required_params if param not in self.config["env"]]
        assert not missing_params, "Base environment missing required parameters: {}".format(missing_params)

        self.lower_bound = torch.tensor(lower_bound, device=self.device)  # Lower bounds of the state space, shape (dim,)
        self.upper_bound = torch.tensor(upper_bound, device=self.device)  # Upper bounds of the state space, shape (dim,)
        # NOTE(review): config["env"]["mixture_dim"] is required above, but the constructor argument is what is stored.
        self.mixture_dim = mixture_dim  # Number of mixture components in the parameterisation of the policy
        self.output_dim = output_dim    # Dimension of the output of the policy network
        self.set_num_grid_points()      # Number of grid points in the state space in each dimension, shape (dim,)
        if not torch.is_tensor(self.num_grid_points) or self.num_grid_points.ndimension() != 1:
            raise ValueError("num_grid_points must be a 1-dimensional array")
        assert len(self.num_grid_points) == self.dim, "num_grid_points must be specified for each dimension of the state space"
        # With fewer than two points the spacing below divides by zero and the reward
        # normalisation silently turns into inf/nan.
        if torch.any(self.num_grid_points < 2):
            raise ValueError("num_grid_points must be at least 2 in every dimension")

        self.set_init_value()

        # State space grid
        self.grid_spacing = (self.upper_bound - self.lower_bound) / (self.num_grid_points - 1)  # Spacing between grid points in each dimension
        steps = [int(n) for n in self.num_grid_points.tolist()]
        # Per-dimension 1-D grids; the endpoints lower_bound[i] and upper_bound[i] are included.
        self.marginal_grid = tuple(
            torch.linspace(self.lower_bound[i], self.upper_bound[i], steps[i], device=self.device)
            for i in range(self.dim)
        )
        # Full grid of shape (n_1, ..., n_dim, dim); "ij" indexing keeps tensor axis i aligned with state dimension i.
        self.grid = torch.stack(torch.meshgrid(*self.marginal_grid, indexing="ij"), dim=-1)
        # n_i + 1 equally spaced bin edges spanning [lower, upper] in each dimension (used by get_onpolicy_dist).
        # NOTE(review): bins are (upper - lower) / n_i wide, while grid points are (upper - lower) / (n_i - 1)
        # apart, so bin centres and grid points only coincide at the middle of the range — confirm intended.
        self.grid_bins = tuple(
            torch.linspace(self.lower_bound[i], self.upper_bound[i], steps[i] + 1, device=self.device)
            for i in range(self.dim)
        )

        # Done last: this calls the subclass's log_reward on self.grid.
        self.reward_grid, self.target_density, self.reward_Z = self.get_reward_grid_density_and_Z()

    def set_num_grid_points(self) -> None:
        """Set ``self.num_grid_points`` from ``config["env"]["num_grid_points"]``.

        An int is broadcast to every dimension; anything else is converted with ``torch.tensor``.
        """
        num_grid_points = self.config["env"]["num_grid_points"]
        if isinstance(num_grid_points, int):
            self.num_grid_points = torch.full((self.dim,), num_grid_points, device=self.device)
        else:
            self.num_grid_points = torch.tensor(num_grid_points, device=self.device)

    def set_init_value(self) -> None:
        """Set ``self.init_value`` (the initial state) from ``config["env"]["init_value"]``.

        A scalar int or float is broadcast to a float tensor of shape ``(dim,)``; an iterable is converted
        with ``torch.tensor``.

        Raises:
            AssertionError: If the initial value lies outside ``[lower_bound, upper_bound]``.
        """
        init_value = self.config["env"]["init_value"]

        if isinstance(init_value, (int, float)):
            # Scalars (including ints such as 0) are broadcast so the state always has shape (dim,).
            self.init_value = torch.full((self.dim,), float(init_value), device=self.device)
        else:
            self.init_value = torch.tensor(init_value, device=self.device)  # Assume init_value is iterable

        assert torch.all(self.init_value >= self.lower_bound) and torch.all(self.init_value <= self.upper_bound), "init_value must be within the bounds of the environment"

    def get_reward_grid_density_and_Z(self):
        """Evaluate the reward on the state-space grid and normalise it.

        Returns:
            reward_grid: ``R(x) = exp(log_reward(x))`` at every grid point.
            target_density: ``R(x) dx / Z`` per grid point, summing to 1, where ``dx = prod(grid_spacing)``.
            reward_Z: ``Z = sum_x R(x) dx``, a Riemann-sum estimate of the normalisation constant.
        """
        reward_grid = self.log_reward(self.grid.clone()).exp()
        target_density = reward_grid * torch.prod(self.grid_spacing)    # unnormalised
        reward_Z = target_density.sum()
        target_density = target_density / reward_Z                      # normalised

        if self.dim == 1:
            # NOTE(review): only changes the layout if log_reward returns a 2-D tensor for the (N, 1) grid.
            target_density = target_density.T

        return reward_grid, target_density, reward_Z

    def get_onpolicy_dist(self, on_policy_samples):
        """Return the empirical distribution of terminal states on the grid bins. Used for plotting and error calculations.

        Args:
            on_policy_samples: Batch of trajectories. The terminal state is taken from the last step with the
                last feature dropped (assumed layout (batch, steps, dim + 1) — TODO confirm against the sampler).

        Returns:
            Tensor with one entry per grid bin (shape ``num_grid_points``) on ``self.device``. Each entry is the
            fraction of samples falling in that bin. Points outside the bin edges are not counted by
            ``torch.histogramdd``.
        """
        number_of_samples = on_policy_samples.shape[0]
        terminal_states = on_policy_samples[:, -1, :-1].to("cpu")  # histogramdd only works on CPU
        # histogramdd requires the bin edges to share the samples' dtype.
        cpu_grid_bins = tuple(bin_data.to(device="cpu", dtype=terminal_states.dtype) for bin_data in self.grid_bins)
        empirical_distribution, _ = torch.histogramdd(terminal_states, bins=cpu_grid_bins)
        empirical_distribution = empirical_distribution / number_of_samples

        return empirical_distribution.to(self.device)

    def get_policy_dist(self, params):
        """Initialise the policy distribution from the raw policy-network parameters."""
        param_dict = self.postprocess_params(params)
        policy_dist = self._init_policy_dist(param_dict)

        return policy_dist

    def get_exploration_dist(self, params, off_policy_noise):
        """Initialise the exploration distribution: the policy distribution with ``off_policy_noise`` added to its parameters."""
        param_dict = self.postprocess_params(params)
        exploration_dist = self._init_policy_dist(self.add_noise(param_dict, off_policy_noise))

        return exploration_dist

    def featurisation(self, states):
        """Given a batch of states, return their feature representation (identity by default)."""
        return states

    @abstractmethod
    def log_reward(self, x):
        """Given a terminating state x, returns the log reward of that state."""

    @abstractmethod
    def step(self, x, action):
        """Given a state x and an action, returns the next state."""

    @abstractmethod
    def backward_step(self, x, action):
        """Given a state x and an action, returns the previous state."""

    @abstractmethod
    def compute_initial_action(self, first_state):
        """Given the first state of the trajectory, returns the initial action that was taken to reach that state from the initial state."""

    @abstractmethod
    def _init_policy_dist(self, param_dict: dict):
        """Returns the policy distribution given the parameters of the policy."""

    @abstractmethod
    def postprocess_params(self, params: torch.Tensor):
        """Post-processes the parameters of the policy to ensure they are within specified bounds."""

    @abstractmethod
    def add_noise(self, param_dict: dict, off_policy_noise: float):
        """Adds noise to the parameters of the policy."""

    